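{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Wide and Deep\n", "\n", "```{note}\n", "The Deep part is the same as the Embedding+MLP model; the Wide part is responsible for memorization.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Structure\n", "\n", "![jupyter](../images/wide1.jpg)\n", "\n", "The left side is the Wide part and the right side is the Deep part.\n", "\n", "Wide part: connects the input layer directly to the output layer, with no hidden layers in between. Its role is to give the model strong memorization.\n", "\n", "Deep part: a typical Embedding + MLP structure. Its role is to give the model strong generalization.\n", "\n", "*Memorization* means that the model directly learns the co-occurrence frequency of items or features and uses it directly as evidence for recommendation, for example the rule 'users who like movie A also like movie B'.\n", "\n", "Such rules have two properties: (1) there are a very large number of them; (2) they are very specific, so there is no need to cross them with other features.\n", "\n", "By combining the two parts, the Wide&Deep model gets both memorization and generalization." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow import keras\n", "import rec\n", "\n", "# Load the MovieLens training and test sets\n", "train_dataset, test_dataset = rec.load_movielens()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "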
" ], "text/plain": [ " movieId userId rating timestamp label releaseYear movieGenre1 \\\n", "0 1 15555 3.0 900953740 0 1995 Adventure \n", "1 1 25912 3.5 1111631768 1 1995 Adventure \n", "2 1 29912 3.0 866820360 0 1995 Adventure \n", "3 10 17686 0.5 1195555011 0 1995 Action \n", "4 104 20158 4.0 1155357691 1 1996 Comedy \n", "... ... ... ... ... ... ... ... \n", "88822 968 26865 3.0 854092232 0 1968 Horror \n", "88823 968 8507 2.0 974709061 0 1968 Horror \n", "88824 969 16689 5.0 857854044 1 1951 Adventure \n", "88825 969 26460 2.0 1250279576 0 1951 Adventure \n", "88826 970 3033 2.0 1272394603 0 1953 Adventure \n", "\n", " movieGenre2 movieGenre3 movieRatingCount ... userRatingCount \\\n", "0 Animation Children 10759 ... 92 \n", "1 Animation Children 10759 ... 21 \n", "2 Animation Children 10759 ... 4 \n", "3 Adventure Thriller 6330 ... 35 \n", "4 NaN NaN 3954 ... 81 \n", "... ... ... ... ... ... \n", "88822 Sci-Fi Thriller 1824 ... 94 \n", "88823 Sci-Fi Thriller 1824 ... 5 \n", "88824 Comedy Romance 2380 ... 97 \n", "88825 Comedy Romance 2380 ... 55 \n", "88826 Comedy Crime 98 ... 100 \n", "\n", " userAvgReleaseYear userReleaseYearStddev userAvgRating \\\n", "0 1992 8.98 3.86 \n", "1 1988 14.09 3.48 \n", "2 1995 0.50 3.00 \n", "3 1992 8.35 2.97 \n", "4 1991 8.70 3.60 \n", "... ... ... ... \n", "88822 1991 12.23 3.35 \n", "88823 1994 0.89 2.00 \n", "88824 1992 9.95 3.53 \n", "88825 1990 11.78 2.73 \n", "88826 1985 17.64 3.67 \n", "\n", " userRatingStddev userGenre1 userGenre2 userGenre3 userGenre4 \\\n", "0 0.74 Drama Comedy Thriller Action \n", "1 1.28 Action Comedy Romance Adventure \n", "2 0.00 NaN NaN NaN NaN \n", "3 1.48 Comedy Drama Adventure Action \n", "4 0.72 Thriller Drama Action Crime \n", "... ... ... ... ... ... 
\n", "88822 0.85 Drama Thriller Comedy Crime \n", "88823 1.00 NaN NaN NaN NaN \n", "88824 0.82 Drama Comedy Crime Romance \n", "88825 1.42 Thriller Crime Drama Comedy \n", "88826 0.89 Drama Romance Comedy Thriller \n", "\n", " userGenre5 \n", "0 Crime \n", "1 Thriller \n", "2 NaN \n", "3 Thriller \n", "4 Adventure \n", "... ... \n", "88822 Romance \n", "88823 NaN \n", "88824 Thriller \n", "88825 Sci-Fi \n", "88826 Crime \n", "\n", "[88827 rows x 27 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = rec.get_movielens_df()\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Deep part\n", "\n", "The features are processed the same way as in the previous section." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Movie genre vocabulary\n", "genre_vocab = ['Film-Noir', 'Action', 'Adventure', 'Horror', 'Romance', 'War',\n", "               'Comedy', 'Western', 'Documentary', 'Sci-Fi', 'Drama', 'Thriller',\n", "               'Crime', 'Fantasy', 'Animation', 'IMAX', 'Mystery', 'Children', 'Musical']\n", "# Genre features, all sharing the same vocabulary\n", "GENRE_FEATURES = {\n", "    'userGenre1': genre_vocab,\n", "    'userGenre2': genre_vocab,\n", "    'userGenre3': genre_vocab,\n", "    'userGenre4': genre_vocab,\n", "    'userGenre5': genre_vocab,\n", "    'movieGenre1': genre_vocab,\n", "    'movieGenre2': genre_vocab,\n", "    'movieGenre3': genre_vocab\n", "}\n", "\n", "categorical_columns = []\n", "for feature, vocab in GENRE_FEATURES.items():\n", "    # First map the genre string to a categorical id (conceptually a one-hot vector)\n", "    cat_col = tf.feature_column.categorical_column_with_vocabulary_list(\n", "        key=feature, vocabulary_list=vocab)\n", "    # Then convert it to a 10-dimensional embedding column\n", "    emb_col = tf.feature_column.embedding_column(cat_col, 10)\n", "    categorical_columns.append(emb_col)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# movie id embedding feature\n", "movie_col = tf.feature_column.categorical_column_with_identity(key='movieId', num_buckets=1001)\n", "movie_emb_col = tf.feature_column.embedding_column(movie_col, 10)\n", 
"categorical_columns.append(movie_emb_col)\n", "\n", "# user id embedding feature\n", "user_col = tf.feature_column.categorical_column_with_identity(key='userId', num_buckets=30001)\n", "user_emb_col = tf.feature_column.embedding_column(user_col, 10)\n", "categorical_columns.append(user_emb_col)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# all numerical features\n", "numerical_columns = [tf.feature_column.numeric_column('releaseYear'),\n", "                     tf.feature_column.numeric_column('movieRatingCount'),\n", "                     tf.feature_column.numeric_column('movieAvgRating'),\n", "                     tf.feature_column.numeric_column('movieRatingStddev'),\n", "                     tf.feature_column.numeric_column('userRatingCount'),\n", "                     tf.feature_column.numeric_column('userAvgRating'),\n", "                     tf.feature_column.numeric_column('userRatingStddev')]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Wide part\n", "\n", "The Wide part uses the cross of two features: the candidate movie's `movieId` and `userRatedMovie1`, a movie the user has rated before. Each (rated movie, candidate movie) pair becomes a feature of its own, which lets the model memorize rules such as 'users who like movie A also like movie B'." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# define input for keras model\n", "inputs = {\n", "    'movieAvgRating': tf.keras.layers.Input(name='movieAvgRating', shape=(), dtype='float32'),\n", "    'movieRatingStddev': tf.keras.layers.Input(name='movieRatingStddev', shape=(), dtype='float32'),\n", "    'movieRatingCount': tf.keras.layers.Input(name='movieRatingCount', shape=(), dtype='int32'),\n", "    'userAvgRating': tf.keras.layers.Input(name='userAvgRating', shape=(), dtype='float32'),\n", "    'userRatingStddev': tf.keras.layers.Input(name='userRatingStddev', shape=(), dtype='float32'),\n", "    'userRatingCount': tf.keras.layers.Input(name='userRatingCount', shape=(), dtype='int32'),\n", "    'releaseYear': tf.keras.layers.Input(name='releaseYear', shape=(), dtype='int32'),\n", "\n", "    'movieId': tf.keras.layers.Input(name='movieId', shape=(), dtype='int32'),\n", "    'userId': tf.keras.layers.Input(name='userId', shape=(), dtype='int32'),\n", "    'userRatedMovie1': 
tf.keras.layers.Input(name='userRatedMovie1', shape=(), dtype='int32'),\n", "\n", "    'userGenre1': tf.keras.layers.Input(name='userGenre1', shape=(), dtype='string'),\n", "    'userGenre2': tf.keras.layers.Input(name='userGenre2', shape=(), dtype='string'),\n", "    'userGenre3': tf.keras.layers.Input(name='userGenre3', shape=(), dtype='string'),\n", "    'userGenre4': tf.keras.layers.Input(name='userGenre4', shape=(), dtype='string'),\n", "    'userGenre5': tf.keras.layers.Input(name='userGenre5', shape=(), dtype='string'),\n", "    'movieGenre1': tf.keras.layers.Input(name='movieGenre1', shape=(), dtype='string'),\n", "    'movieGenre2': tf.keras.layers.Input(name='movieGenre2', shape=(), dtype='string'),\n", "    'movieGenre3': tf.keras.layers.Input(name='movieGenre3', shape=(), dtype='string'),\n", "}" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "rated_movie = tf.feature_column.categorical_column_with_identity(\n", "    key='userRatedMovie1', num_buckets=1001)\n", "# Use the cross of movie_col and rated_movie as the input of the Wide part.\n", "# crossed_column hashes each (movieId, userRatedMovie1) pair into one of 10000 buckets\n", "# (distinct pairs can collide); indicator_column turns the bucket id into a one-hot vector.\n", "crossed_feature = tf.feature_column.indicator_column(\n", "    tf.feature_column.crossed_column([movie_col, rated_movie], 10000))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the Model\n", "\n", "The model is defined with the Keras functional API. The Deep part passes all numerical and embedding features through two 128-unit ReLU layers. The Wide part's crossed feature is concatenated with the Deep output and fed to a single sigmoid unit, which is equivalent to summing a linear model over the crossed feature and a linear layer over the Deep output before the sigmoid." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# wide and deep model architecture\n", "# deep part for all input features\n", "deep = tf.keras.layers.DenseFeatures(numerical_columns + categorical_columns)(inputs)\n", "deep = tf.keras.layers.Dense(128, activation='relu')(deep)\n", "deep = tf.keras.layers.Dense(128, activation='relu')(deep)\n", "\n", "# wide part for cross feature\n", "wide = tf.keras.layers.DenseFeatures(crossed_feature)(inputs)\n", "\n", "both = tf.keras.layers.concatenate([deep, wide])\n", "output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(both)\n", "model = tf.keras.Model(inputs, output_layer)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# compile the model, set loss function, optimizer and evaluation metrics\n", "model.compile(\n", "    loss='binary_crossentropy',\n", "    optimizer='adam',\n", "    metrics=['accuracy', tf.keras.metrics.AUC(curve='ROC'), tf.keras.metrics.AUC(curve='PR')])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/facer/opt/anaconda3/lib/python3.8/site-packages/keras/engine/functional.py:582: UserWarning: Input dict contained keys ['rating', 'timestamp', 'userRatedMovie2', 'userRatedMovie3', 'userRatedMovie4', 'userRatedMovie5', 'userAvgReleaseYear', 'userReleaseYearStddev'] which did not match any model input. They will be ignored by the model.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "7403/7403 [==============================] - 24s 3ms/step - loss: 0.7510 - accuracy: 0.6077 - auc: 0.6272 - auc_1: 0.6638\n", "Epoch 2/5\n", "7403/7403 [==============================] - 20s 3ms/step - loss: 0.6049 - accuracy: 0.6767 - auc: 0.7304 - auc_1: 0.7556\n", "Epoch 3/5\n", "7403/7403 [==============================] - 21s 3ms/step - loss: 0.5482 - accuracy: 0.7214 - auc: 0.7897 - auc_1: 0.8113\n", "Epoch 4/5\n", "7403/7403 [==============================] - 20s 3ms/step - loss: 0.5051 - accuracy: 0.7546 - auc: 0.8270 - auc_1: 0.8471\n", "Epoch 5/5\n", "7403/7403 [==============================] - 20s 3ms/step - loss: 0.4816 - accuracy: 0.7691 - auc: 0.8452 - auc_1: 0.8668\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train the model\n", "model.fit(train_dataset, epochs=5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } 
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }